# -*- coding: utf-8 -*-
# VGGNet
# 2023/4/12
# Project:VGGNet
# Creator:Administrator
# Create time:2023-04-12-16-20
# IDE:PyCharm


import cv2
import os
import random
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
"""
Dataset: iChallenge-PM — 数据集iChallenge-PM（眼疾识别, eye-disease recognition）
"""


def transform_img(img):
    """Turn an (H, W, C) image array into a normalized (C, H, W) float32 array.

    The image is resized to 224x224, channels are moved to the front,
    and pixel values are rescaled linearly so that 0 maps to -1.0 and
    255 maps to 1.0 (assumes 8-bit pixel values, as returned by
    ``cv2.imread``).
    """
    resized = cv2.resize(img, (224, 224))
    chw = resized.transpose((2, 0, 1)).astype('float32')
    # x / 255 maps into [0, 1]; "* 2 - 1" then stretches that to [-1, 1].
    return chw / 255. * 2.0 - 1.0


# First character of a training-image file name -> binary label.
_TRAIN_PREFIX_LABELS = {'H': 0, 'N': 0, 'P': 1}


def _label_from_name(name):
    """Return the label encoded by the first character of ``name``.

    'H' and 'N' give 0 and 'P' gives 1. Any other prefix prints a
    diagnostic and terminates the process with exit status -1.
    """
    import sys

    prefix = name[:1]
    if prefix not in _TRAIN_PREFIX_LABELS:
        print('Not expected file name')
        print(prefix)
        # sys.exit is always available; the bare ``exit`` builtin is only
        # injected by the ``site`` module (missing e.g. under ``python -S``).
        sys.exit(-1)
    return _TRAIN_PREFIX_LABELS[prefix]


def data_loader(datadir, batch_size=10, mode='train'):
    """Build a batched reader over the labelled images in ``datadir``.

    Each file's label comes from the first character of its name (see
    ``_label_from_name``); an unknown prefix aborts the program. Images are
    preprocessed with ``transform_img``.

    Args:
        datadir: directory whose every entry is a labelled image file.
        batch_size: number of samples per yielded batch; the final batch
            may be smaller.
        mode: when ``'train'``, the file order is reshuffled on every pass.

    Returns:
        A zero-argument generator function. Each call yields
        ``(imgs, labels)`` numpy pairs: ``imgs`` is float32 with shape
        (B, 3, 224, 224) and ``labels`` is float32 with shape (B, 1).

    Raises:
        IOError: (from the reader) if an image file cannot be decoded.
    """
    # os.listdir order is arbitrary; sort so non-train passes are reproducible.
    filenames = sorted(os.listdir(datadir))

    def pack(imgs, labels):
        # Stack collected samples into the (images, labels) batch arrays.
        return (np.array(imgs).astype('float32'),
                np.array(labels).astype('float32').reshape(-1, 1))

    def reader():
        if mode == 'train':
            random.shuffle(filenames)
        batch_imgs = []
        batch_labels = []
        for name in filenames:
            # Validate the name before decoding so stray non-image files get
            # the clear diagnostic instead of an obscure cv2 failure.
            label = _label_from_name(name)
            filepath = os.path.join(datadir, name)
            img = cv2.imread(filepath)
            if img is None:  # cv2.imread signals failure by returning None
                raise IOError('Failed to read image: {}'.format(filepath))
            batch_imgs.append(transform_img(img))
            batch_labels.append(label)
            if len(batch_imgs) == batch_size:
                yield pack(batch_imgs, batch_labels)
                batch_imgs = []
                batch_labels = []
        # Flush the final, possibly smaller, batch.
        if batch_imgs:
            yield pack(batch_imgs, batch_labels)

    return reader


def _read_valid_index(csvfile):
    """Return ``[(file_name, label), ...]`` parsed from a validation CSV.

    The first row is treated as a header and skipped; empty or
    whitespace-only rows are ignored. In every other row, column 1 is the
    image file name and column 2 the integer label.
    """
    import csv

    # Read everything up front inside ``with`` so the file is always closed.
    with open(csvfile, newline='') as f:
        rows = list(csv.reader(f))
    entries = []
    for row in rows[1:]:
        if not any(field.strip() for field in row):
            continue  # tolerate blank lines, e.g. extra newlines at EOF
        entries.append((row[1], int(row[2])))
    return entries


def valid_data_loader(datadir, csvfile, batch_size=10, mode='valid'):
    """Build a batched reader over validation images listed in ``csvfile``.

    The CSV is parsed once, when this function is called (see
    ``_read_valid_index``); image paths are resolved relative to
    ``datadir`` and preprocessed with ``transform_img``.

    Args:
        datadir: directory containing the validation images.
        csvfile: path of the label CSV.
        batch_size: number of samples per yielded batch; the final batch
            may be smaller.
        mode: unused; kept for signature compatibility.

    Returns:
        A zero-argument generator function yielding ``(imgs, labels)``
        numpy pairs: float32 images (B, 3, 224, 224) and float32 labels
        (B, 1). Batches follow the CSV row order.

    Raises:
        IOError: (from the reader) if an image file cannot be decoded.
    """
    entries = _read_valid_index(csvfile)

    def pack(imgs, labels):
        # Stack collected samples into the (images, labels) batch arrays.
        return (np.array(imgs).astype('float32'),
                np.array(labels).astype('float32').reshape(-1, 1))

    def reader():
        batch_imgs = []
        batch_labels = []
        for name, label in entries:
            filepath = os.path.join(datadir, name)
            img = cv2.imread(filepath)
            if img is None:  # cv2.imread signals failure by returning None
                raise IOError('Failed to read image: {}'.format(filepath))
            batch_imgs.append(transform_img(img))
            batch_labels.append(label)
            if len(batch_imgs) == batch_size:
                yield pack(batch_imgs, batch_labels)
                batch_imgs = []
                batch_labels = []
        # Flush the final, possibly smaller, batch.
        if batch_imgs:
            yield pack(batch_imgs, batch_labels)

    return reader


# Dataset locations, all relative to the working directory's ./res folder.
_RES_DIR = './res'
DATADIR = _RES_DIR + '/PALM-Training400'     # training images
DATADIR2 = _RES_DIR + '/PALM-Validation400'  # validation images
CSCVFILE = _RES_DIR + '/label.csv'           # validation label CSV


# Define the LeNet network structure.
class LeNet(fluid.dygraph.Layer):
    """LeNet-style classifier producing ``num_classes`` raw logits.

    All convolutions are unpadded with sigmoid activations; the final
    layer has no activation. ``fc1``'s input size is tied to 224x224
    inputs (the size produced by ``transform_img``).
    """

    def __init__(self, name_scope, num_classes=1):
        super(LeNet, self).__init__(name_scope)

        # Sublayers are created in a fixed order: creation order drives
        # parameter naming and initialization, so keep it stable.
        self.conv1 = Conv2D(num_channels=3, num_filters=6,
                            filter_size=5, act='sigmoid')      # 224 -> 220
        self.pool1 = Pool2D(pool_size=2, pool_stride=2,
                            pool_type='max')                   # 220 -> 110
        self.conv2 = Conv2D(num_channels=6, num_filters=16,
                            filter_size=5, act='sigmoid')      # 110 -> 106
        self.pool2 = Pool2D(pool_size=2, pool_stride=2,
                            pool_type='max')                   # 106 -> 53
        self.conv3 = Conv2D(num_channels=16, num_filters=120,
                            filter_size=4, act='sigmoid')      # 53 -> 50
        # conv3 emits 120 channels of 50x50 for a 224x224 input.
        flat_features = 120 * 50 * 50
        self.fc1 = Linear(input_dim=flat_features, output_dim=64,
                          act='sigmoid')
        self.fc2 = Linear(input_dim=64, output_dim=num_classes)

    def forward(self, x):
        for stage in (self.conv1, self.pool1, self.conv2, self.pool2,
                      self.conv3):
            x = stage(x)
        # Flatten everything except the batch dimension.
        flat = fluid.layers.reshape(x, [x.shape[0], -1])
        return self.fc2(self.fc1(flat))


# Define the AlexNet network structure.
class AlexNet(fluid.dygraph.Layer):
    """AlexNet-style classifier producing ``num_classes`` raw logits.

    Sized for 3x224x224 inputs (as produced by ``transform_img``). The
    five convolutions and three 2x2 max-pools reduce such an input to a
    256 x 7 x 7 map, i.e. 12544 features for ``fc1``. Dropout after
    ``fc1`` and ``fc2`` is active only in training mode.
    """

    def __init__(self, name_scope, num_classes=1):
        super(AlexNet, self).__init__(name_scope)
        # Spatial size for a 224x224 input is noted after each stage.
        self.conv1 = Conv2D(num_channels=3,
                            num_filters=96,
                            filter_size=11,
                            stride=4,
                            padding=5,
                            act='relu')                        # -> 56x56
        self.pool1 = Pool2D(pool_size=2, pool_stride=2,
                            pool_type='max')                   # -> 28x28
        self.conv2 = Conv2D(num_channels=96,
                            num_filters=256,
                            filter_size=5,
                            stride=1,
                            padding=2,
                            act='relu')                        # -> 28x28
        self.pool2 = Pool2D(pool_size=2, pool_stride=2,
                            pool_type='max')                   # -> 14x14
        self.conv3 = Conv2D(num_channels=256,
                            num_filters=384,
                            filter_size=3,
                            stride=1,
                            padding=1,
                            act='relu')                        # -> 14x14
        self.conv4 = Conv2D(num_channels=384,
                            num_filters=384,
                            filter_size=3,
                            stride=1,
                            padding=1,
                            act='relu')                        # -> 14x14
        self.conv5 = Conv2D(num_channels=384,
                            num_filters=256,
                            filter_size=3,
                            stride=1,
                            padding=1,
                            act='relu')                        # -> 14x14
        self.pool5 = Pool2D(pool_size=2, pool_stride=2,
                            pool_type='max')                   # -> 7x7

        # 256 channels * 7 * 7 = 12544 flattened features.
        self.fc1 = Linear(input_dim=12544, output_dim=4096, act='relu')
        self.drop_ratio1 = 0.5
        self.fc2 = Linear(input_dim=4096, output_dim=4096, act='relu')
        self.drop_ratio2 = 0.5
        # No activation: raw logits for sigmoid_cross_entropy_with_logits.
        self.fc3 = Linear(input_dim=4096, output_dim=num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.pool5(x)
        # Flatten everything except the batch dimension.
        x = fluid.layers.reshape(x, [x.shape[0], -1])
        # Tie dropout to the layer's train/eval mode explicitly instead of
        # relying on the default ``is_test=False``, so validation after
        # ``model.eval()`` is deterministic.
        is_test = not self.training
        x = self.fc1(x)
        x = fluid.layers.dropout(x, self.drop_ratio1, is_test=is_test)
        x = self.fc2(x)
        x = fluid.layers.dropout(x, self.drop_ratio2, is_test=is_test)
        x = self.fc3(x)
        return x


def _evaluate(model, valid_loader):
    """Return ``(accuracy, mean_loss)`` of ``model`` over ``valid_loader()``.

    Accuracy is weighted by batch size and the loss is averaged over all
    individual samples, so a smaller final batch neither skews the metrics
    nor produces a ragged array. Returns ``(nan, nan)`` when the loader
    yields nothing. The caller is responsible for switching the model into
    eval mode beforehand.
    """
    correct = 0.0
    seen = 0
    sample_losses = []
    for x_data, y_data in valid_loader():
        img = fluid.dygraph.to_variable(x_data)
        label = fluid.dygraph.to_variable(y_data)
        logits = model(img)
        loss = fluid.layers.sigmoid_cross_entropy_with_logits(logits, label)
        # accuracy() takes per-class scores, so expand the single
        # P(label == 1) column into [P(label == 0), P(label == 1)].
        prob = fluid.layers.sigmoid(logits)
        scores = fluid.layers.concat([prob * (-1.0) + 1.0, prob], axis=1)
        acc = fluid.layers.accuracy(
            scores, fluid.layers.cast(label, dtype='int64'))
        batch_n = x_data.shape[0]
        correct += acc.numpy().item() * batch_n
        seen += batch_n
        sample_losses.append(loss.numpy().reshape(-1))
    if seen == 0:
        return float('nan'), float('nan')
    return correct / seen, float(np.mean(np.concatenate(sample_losses)))


# Define the training procedure.
def train(model, epoch_num=5, learning_rate=0.001, momentum=0.9,
          save_path='./result/iChallengePM'):
    """Train ``model`` with Momentum SGD, validating after every epoch.

    Training batches come from ``DATADIR`` and validation batches from
    ``DATADIR2``/``CSCVFILE``. Loss is sigmoid cross-entropy on the
    model's logits. When all epochs finish, the model and optimizer state
    dicts are saved under ``save_path``.

    Args:
        model: a ``fluid.dygraph.Layer`` mapping images to one logit each.
        epoch_num: number of passes over the training set.
        learning_rate: Momentum optimizer learning rate.
        momentum: Momentum optimizer momentum coefficient.
        save_path: path prefix passed to ``fluid.save_dygraph``.
    """
    with fluid.dygraph.guard():
        print("---- start training ----")
        model.train()
        opt = fluid.optimizer.Momentum(learning_rate=learning_rate,
                                       momentum=momentum,
                                       parameter_list=model.parameters())
        train_loader = data_loader(DATADIR, batch_size=10, mode='train')
        valid_loader = valid_data_loader(DATADIR2, CSCVFILE)
        for epoch in range(epoch_num):
            for batch_id, (x_data, y_data) in enumerate(train_loader()):
                img = fluid.dygraph.to_variable(x_data)
                label = fluid.dygraph.to_variable(y_data)
                logits = model(img)
                loss = fluid.layers.sigmoid_cross_entropy_with_logits(
                    logits, label)
                avg_loss = fluid.layers.mean(loss)

                if batch_id % 10 == 0:
                    print("epoch: {}, batch_id: {}, loss is: {}".format(
                        epoch, batch_id, avg_loss.numpy()))
                avg_loss.backward()
                opt.minimize(avg_loss)
                model.clear_gradients()

            model.eval()
            val_acc, val_loss = _evaluate(model, valid_loader)
            print("[validation accuracy/loss: {}/{}]".format(
                val_acc, val_loss))
            model.train()
        # NOTE(review): both calls share one path prefix; this relies on
        # save_dygraph picking distinct file suffixes for model vs optimizer
        # state — confirm for the installed Paddle version.
        fluid.save_dygraph(model.state_dict(), save_path)
        fluid.save_dygraph(opt.state_dict(), save_path)


if __name__ == "__main__":
    # Build the network under a dygraph guard so its parameters are created
    # in imperative mode; replace with LeNet("LeNet") to train the smaller net.
    with fluid.dygraph.guard():
        net = AlexNet("AlexNet")
    train(net)

